{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "__This notebook__ computes teacher predictions (softmax over the logits of a pretrained ResNet-18) on the ImageNet training set, for editable fine-tuning on ImageNet"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "env: CUDA_VISIBLE_DEVICES=5\n"
     ]
    }
   ],
   "source": [
    "%load_ext autoreload\n",
    "%autoreload 2\n",
    "%env CUDA_VISIBLE_DEVICES=YOURDEVICEHERE\n",
    "\n",
    "# Root of the ImageNet training images, read below through torchvision's ImageFolder.\n",
    "imagenet_train_path = '../../imagenet/train'\n",
    "# Directory that receives the arrays written by this notebook.\n",
    "logits_path = './imagenet_logits/'\n",
    "\n",
    "import os, sys, time\n",
    "sys.path.insert(0, '..')\n",
    "import lib\n",
    "\n",
    "import numpy as np\n",
    "import torch, torch.nn as nn\n",
    "import torch.nn.functional as F\n",
    "import matplotlib.pyplot as plt\n",
    "%matplotlib inline\n",
    "\n",
    "import random\n",
    "\n",
    "# Seed the Python, NumPy and PyTorch random number generators with one fixed value.\n",
    "for seed_rng in (random.seed, np.random.seed, torch.random.manual_seed):\n",
    "    seed_rng(42)\n",
    "\n",
    "import time\n",
    "\n",
    "# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.\n",
    "if torch.cuda.is_available():\n",
    "    device = 'cuda'\n",
    "else:\n",
    "    device = 'cpu'\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torchvision.transforms as transforms\n",
    "import torchvision.datasets as datasets\n",
    "import torchvision.models as models\n",
    "\n",
    "# Per-channel normalisation applied to every image tensor.\n",
    "normalize = transforms.Normalize(\n",
    "    mean=[0.485, 0.456, 0.406],\n",
    "    std=[0.229, 0.224, 0.225],\n",
    ")\n",
    "\n",
    "# NOTE(review): the crop and the flip are random, so each stored teacher output\n",
    "# belongs to one particular random view of its image -- confirm this is intended.\n",
    "train_dataset = datasets.ImageFolder(\n",
    "    root=imagenet_train_path,\n",
    "    transform=transforms.Compose([\n",
    "        transforms.RandomResizedCrop(224),\n",
    "        transforms.RandomHorizontalFlip(),\n",
    "        transforms.ToTensor(),\n",
    "        normalize,\n",
    "    ]),\n",
    ")\n",
    "\n",
    "batch_size = 64\n",
    "\n",
    "# shuffle=False keeps batches in dataset order; the saving loop depends on this to\n",
    "# match each output row with train_loader.dataset.imgs by a running index.\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "    dataset=train_dataset,\n",
    "    batch_size=batch_size,\n",
    "    shuffle=False,\n",
    "    num_workers=1,\n",
    "    pin_memory=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from tqdm import tqdm_notebook, tnrange\n",
    "from IPython.display import clear_output\n",
    "\n",
    "import torchvision\n",
    "\n",
    "image_index = 0\n",
    "\n",
    "def get_logits(model, X_test, batch_size=64):\n",
    "    \"\"\"Run ``model`` over ``X_test`` with training mode switched off.\n",
    "\n",
    "    :param model: callable network; inputs are moved to the global ``device`` before the call\n",
    "    :param X_test: batch of inputs handed to ``lib.eval_func``\n",
    "    :param batch_size: chunk size forwarded to ``lib.eval_func``\n",
    "    :returns: whatever ``lib.eval_func`` returns for the wrapped model\n",
    "    \"\"\"\n",
    "    # Use the notebook-wide device rather than a literal 'cuda', so the CPU fallback\n",
    "    # selected at the top of the notebook is honoured here as well.\n",
    "    with lib.training_mode(model, is_train=False):\n",
    "        predict = lib.Lambda(lambda x: model(x.to(device)))\n",
    "        return lib.eval_func(predict, X_test, device=device, batch_size=batch_size)\n",
    "\n",
    "model = torchvision.models.resnet18(pretrained=True).to(device)\n",
    "\n",
    "# np.save does not create missing directories.\n",
    "os.makedirs(logits_path, exist_ok=True)\n",
    "\n",
    "# image_index (set to 0 at the top of this cell) walks train_loader.dataset.imgs in step\n",
    "# with the outputs, which is only valid while the loader iterates in dataset order.\n",
    "for batch in tqdm_notebook(train_loader):\n",
    "    logits = get_logits(model, batch[0])\n",
    "    # One softmax per batch; every row is then stored in its own file.\n",
    "    probabilities = torch.softmax(logits, dim=-1).detach().cpu().numpy()\n",
    "    for image_probabilities in probabilities:\n",
    "        image_path = train_loader.dataset.imgs[image_index][0]\n",
    "        # File name without directory and extension, e.g. 'a/b/img.JPEG' -> 'img'.\n",
    "        image_stem = os.path.splitext(os.path.basename(image_path))[0]\n",
    "        np.save(os.path.join(logits_path, image_stem), image_probabilities)\n",
    "        image_index += 1\n",
    "\n",
    "# Every dataset entry must have been visited exactly once.\n",
    "assert image_index == len(train_loader.dataset.imgs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "9.9G\timagenet_logits\r\n"
     ]
    }
   ],
   "source": [
    "# Disk usage of the output directory configured in logits_path.\n",
    "!du -h {logits_path}"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
